package com.spring.mvc;

import com.alibaba.fastjson.JSON;
import com.spring.mvc.annotation.*;
import org.dom4j.Attribute;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Minimal front-controller servlet.
 *
 * <p>Initialization ({@link #init()}):
 * <ol>
 *   <li>Reads the {@code contextConfigLocation} init parameter.</li>
 *   <li>Parses the referenced XML for {@code <component-scan base-package="...">}.</li>
 *   <li>Scans those packages for {@code @Controller}/{@code @Service} classes and registers
 *       instances in an {@link ApplicationContext}.</li>
 *   <li>Injects {@code @Autowired} fields.</li>
 *   <li>Maps {@code @RequestMapping}/{@code @GetMapping}/{@code @PostMapping} methods by exact
 *       path.</li>
 * </ol>
 *
 * <p>Requests (GET and POST) are dispatched to the handler registered for the request path.
 *
 * @author yangjian
 */
public class DispatcherServlet extends HttpServlet {

	private static final String CLASS_SUFFIX = ".class";

	// Bean container holding controllers and services.
	private ApplicationContext context;

	// Fully-qualified names of every class found under the component-scan packages.
	// A LinkedHashSet keeps scan order while dropping duplicates that overlapping packages
	// (e.g. "com.a" and "com.a.b") would otherwise produce, registering the same bean twice.
	private final Set<String> classFiles = new LinkedHashSet<>();

	// Request path (relative to the context path) -> handler (controller bean + method).
	private final Map<String /*path*/, RequestHandler/*Controller Handler*/> handlerMapping = new ConcurrentHashMap<>();

	/**
	 * Boots the container: locates the XML config, scans component packages, creates beans,
	 * wires {@code @Autowired} fields and builds the handler mapping.
	 *
	 * @throws ServletException if the init parameter is missing, the XML config or a scanned
	 *                          package cannot be found/parsed, or a bean cannot be created
	 */
	@Override
	public void init() throws ServletException
	{
		context = new ApplicationContext();
		// 1. Read the servlet init parameter, e.g. "classpath:spring-mvc.xml".
		String contextConfigLocation = getServletConfig().getInitParameter("contextConfigLocation");
		if (contextConfigLocation == null || contextConfigLocation.trim().isEmpty()) {
			throw new ServletException("Missing servlet init parameter 'contextConfigLocation'");
		}
		// 2. Strip an optional "prefix:" and locate the XML on the classpath.
		String xmlPath = contextConfigLocation.substring(contextConfigLocation.indexOf(':') + 1).trim();
		ClassLoader classLoader = DispatcherServlet.class.getClassLoader();
		URL xmlResource = classLoader.getResource(xmlPath);
		// Explicit check rather than `assert`: assertions are disabled by default at runtime.
		if (xmlResource == null) {
			throw new ServletException("Spring MVC config not found on classpath: " + xmlPath);
		}
		List<String> packages = parseComponentPackage(xmlResource);

		// 3. Scan each package directory for class files.
		for (String packageName : packages) {
			// ClassLoader resource names always use '/', independent of the OS file separator.
			URL packageUrl = classLoader.getResource(packageName.replace('.', '/'));
			if (packageUrl == null) {
				throw new ServletException("Component-scan package not found on classpath: " + packageName);
			}
			scanComponents(toFile(packageUrl), packageName);
		}

		// 4. Instantiate beans and register them in the container.
		createBeans();

		// 5. Inject @Autowired dependencies.
		initBeanAutowired();

		// 6. Build the path -> handler mapping.
		initHandlerMapping();
		super.init();
	}

	/**
	 * Builds {@link #handlerMapping} from every bean method annotated with
	 * {@code @RequestMapping}, {@code @GetMapping} or {@code @PostMapping}.
	 */
	private void initHandlerMapping()
	{
		for (Object bean : context.getBeans().values()) {
			for (Method method : bean.getClass().getDeclaredMethods()) {
				if (method.isAnnotationPresent(RequestMapping.class)) {
					RequestMapping mapping = method.getAnnotation(RequestMapping.class);
					registerHandler(mapping.value(), bean, method, mapping.method());
				} else if (method.isAnnotationPresent(GetMapping.class)) {
					registerHandler(method.getAnnotation(GetMapping.class).value(), bean, method, RequestMapping.GET);
				} else if (method.isAnnotationPresent(PostMapping.class)) {
					registerHandler(method.getAnnotation(PostMapping.class).value(), bean, method, RequestMapping.POST);
				}
			}
		}
	}

	// Registers one handler. Because getDeclaredMethods() has no defined order, a later duplicate
	// silently winning would be nondeterministic, so duplicates are logged.
	private void registerHandler(String path, Object bean, Method method, String httpMethod)
	{
		RequestHandler previous = this.handlerMapping.put(path, new RequestHandler(path, bean, method, httpMethod));
		if (previous != null) {
			log("Duplicate mapping for path '" + path + "': " + previous.getMethod() + " replaced by " + method);
		}
	}

	/**
	 * Injects beans into fields annotated with {@code @Autowired}, looking each one up by the
	 * annotation's value. Only fields declared directly on the bean's class are considered
	 * (superclass fields are not scanned).
	 *
	 * @throws IllegalStateException if a field cannot be written reflectively
	 */
	private void initBeanAutowired()
	{
		for (Object bean : context.getBeans().values()) {
			for (Field field : bean.getClass().getDeclaredFields()) {
				if (field.isAnnotationPresent(Autowired.class)) {
					String beanName = field.getAnnotation(Autowired.class).value();
					field.setAccessible(true);
					try {
						field.set(bean, context.getBean(beanName));
					} catch (IllegalAccessException e) {
						throw new IllegalStateException("Cannot inject bean '" + beanName + "' into " + field, e);
					}
				}
			}
		}
	}

	/**
	 * Instantiates every scanned class annotated with {@code @Controller} or {@code @Service}
	 * (via its no-arg constructor) and registers it in the container.
	 *
	 * @throws ServletException if a scanned class cannot be loaded or a bean cannot be instantiated;
	 *                          the failing class name is included in the message
	 */
	private void createBeans() throws ServletException
	{
		ClassLoader classLoader = DispatcherServlet.class.getClassLoader();
		for (String className : this.classFiles) {
			try {
				// initialize=false: don't run static initializers of classes that turn out not to be beans.
				Class<?> clazz = Class.forName(className, false, classLoader);
				// Interfaces (including annotation types) cannot be instantiated.
				if (clazz.isInterface()) {
					continue;
				}
				String beanName = null;
				if (clazz.isAnnotationPresent(Controller.class)) {
					beanName = getBeanNameByClassName(clazz.getSimpleName());
				} else if (clazz.isAnnotationPresent(Service.class)) {
					beanName = clazz.getAnnotation(Service.class).value();
					// Empty annotation value -> derive the name from the class name.
					if (beanName.isEmpty()) {
						beanName = getBeanNameByClassName(clazz.getSimpleName());
					}
				}
				if (beanName != null) {
					context.registerBean(beanName, clazz.getDeclaredConstructor().newInstance());
				}
			} catch (ReflectiveOperationException | LinkageError e) {
				// Fail init loudly instead of starting with a silently incomplete container.
				throw new ServletException("Failed to create bean from class " + className, e);
			}
		}
	}

	/**
	 * Derives a default bean name by lower-casing the first character of a simple class name,
	 * e.g. {@code "UserController"} -> {@code "userController"}.
	 *
	 * @param className simple class name
	 * @return the bean name; {@code null} or empty input is returned unchanged
	 */
	public String getBeanNameByClassName(String className)
	{
		if (className == null || className.isEmpty()) {
			return className;
		}
		return className.substring(0, 1).toLowerCase(Locale.ROOT) + className.substring(1);
	}

	/**
	 * Parses the MVC XML config and returns the packages listed in
	 * {@code <component-scan base-package="a.b, c.d">}. Entries are comma-separated and trimmed;
	 * blank entries are skipped.
	 *
	 * @throws ServletException if the document cannot be parsed or lacks the element/attribute
	 */
	private List<String> parseComponentPackage(URL xmlResource) throws ServletException
	{
		Document document;
		try {
			document = SAXReader.createDefault().read(xmlResource);
		} catch (DocumentException e) {
			throw new ServletException("Failed to parse Spring MVC config " + xmlResource, e);
		}
		Element componentNode = document.getRootElement().element("component-scan");
		if (componentNode == null) {
			throw new ServletException("No <component-scan> element in " + xmlResource);
		}
		Attribute attr = componentNode.attribute("base-package");
		if (attr == null) {
			throw new ServletException("<component-scan> has no base-package attribute in " + xmlResource);
		}
		List<String> packages = new ArrayList<>();
		for (String candidate : attr.getValue().split(",")) {
			String packageName = candidate.trim();
			if (!packageName.isEmpty()) {
				packages.add(packageName);
			}
		}
		return packages;
	}

	// Converts a classpath URL to a File. toURI() (unlike URL.getFile()) decodes percent-escapes
	// such as "%20", so directories containing spaces are found.
	private static File toFile(URL url) throws ServletException
	{
		try {
			return new File(url.toURI());
		} catch (URISyntaxException | IllegalArgumentException e) {
			// IllegalArgumentException: not a "file:" URL (e.g. the package lives inside a jar).
			throw new ServletException("Cannot scan non-directory classpath resource: " + url, e);
		}
	}

	/**
	 * Recursively collects fully-qualified class names under {@code dir}, the directory of
	 * package {@code packageName}. Names are derived from the package, never from the absolute
	 * file path, so the location of the classpath root does not matter.
	 */
	private void scanComponents(File dir, String packageName)
	{
		File[] children = dir.listFiles();
		if (children == null) {
			// Not a directory, or it could not be read.
			return;
		}
		for (File child : children) {
			String fileName = child.getName();
			if (child.isDirectory()) {
				scanComponents(child, packageName + "." + fileName);
			} else if (fileName.endsWith(CLASS_SUFFIX) && !fileName.contains("-")) {
				// Skips non-class resources and package-info/module-info, which are not loadable classes.
				String simpleName = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
				this.classFiles.add(packageName + "." + simpleName);
			}
		}
	}

	/** GET requests go through the same dispatch as POST. */
	@Override
	protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
	{
		doPost(req, resp);
	}

	/**
	 * Dispatches the request to the mapped controller method.
	 *
	 * <ul>
	 *   <li>A {@code String} result is written as-is.</li>
	 *   <li>Any other result is serialized to JSON when the method carries {@code @ResponseBody}.</li>
	 *   <li>An unmapped path responds 404.</li>
	 *   <li>Any failure responds 500 (if the response is not yet committed).</li>
	 * </ul>
	 */
	@Override
	protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
	{
		try {
			log("receive a new request: " + req.getRequestURI());
			String path = resolveRequestPath(req);
			RequestHandler handler = handlerMapping.get(path);
			// NOTE(review): match() is defined in RequestHandler; presumably it checks the HTTP method.
			if (handler == null || !handler.match(req)) {
				writeHtml(resp, 404, "<h1>404 Page Not Found.</h1>");
				return;
			}
			Method method = handler.getMethod();
			Object result = method.invoke(handler.getController(), resolveArguments(method, req, resp));
			// Status is left at its default (200) so a status set by the controller is preserved.
			if (result instanceof String) {
				// TODO: render through a template resolver once one is configured.
				if (resp.getContentType() == null) {
					resp.setContentType("text/html; charset=utf-8");
				}
				resp.getWriter().println(result);
			} else if (method.isAnnotationPresent(ResponseBody.class)) {
				resp.setContentType("application/json; charset=utf-8");
				resp.getWriter().write(JSON.toJSONString(result));
				resp.getWriter().flush();
			}
		} catch (Exception e) {
			// Log the controller's own exception rather than the reflective wrapper.
			Throwable cause = e instanceof InvocationTargetException && e.getCause() != null ? e.getCause() : e;
			log("Failed to handle request " + req.getRequestURI(), cause);
			if (!resp.isCommitted()) {
				resp.resetBuffer();
				writeHtml(resp, 500, "<h1>500 Internal server error</h1>");
			}
		}
	}

	// Request URI relative to the web application, so mappings also match when the app is
	// deployed under a non-root context path.
	private static String resolveRequestPath(HttpServletRequest req)
	{
		String uri = req.getRequestURI();
		String contextPath = req.getContextPath();
		if (contextPath != null && !contextPath.isEmpty() && uri.startsWith(contextPath)) {
			uri = uri.substring(contextPath.length());
		}
		return uri.isEmpty() ? "/" : uri;
	}

	// Builds invocation arguments. Parameters whose type accepts the request/response object
	// receive it; every other parameter is looked up as a request parameter by its name.
	// Real parameter names are only available when compiled with `-parameters`
	// (otherwise Parameter.getName() yields "arg0", "arg1", ...).
	private static Object[] resolveArguments(Method method, HttpServletRequest req, HttpServletResponse resp)
	{
		Parameter[] parameters = method.getParameters();
		Object[] args = new Object[parameters.length];
		for (int i = 0; i < parameters.length; i++) {
			Class<?> type = parameters[i].getType();
			if (type.isAssignableFrom(req.getClass())) {
				args[i] = req;
			} else if (type.isAssignableFrom(resp.getClass())) {
				args[i] = resp;
			} else {
				args[i] = req.getParameter(parameters[i].getName());
			}
		}
		return args;
	}

	// Writes a small UTF-8 HTML page with the given status.
	private static void writeHtml(HttpServletResponse resp, int status, String html) throws IOException
	{
		resp.setStatus(status);
		resp.setContentType("text/html; charset=utf-8");
		resp.getWriter().println(html);
	}
}
